Unverified · Commit 48342752 · authored by xuezhong · committed by GitHub

Merge pull request #1380 from xuezhong/machine_reading_comprehesion

Machine reading comprehension
@@ -132,61 +132,44 @@ class BRCDataset(object):
             'passage_token_ids': [],
             'passage_length': [],
             'start_id': [],
-            'end_id': []
+            'end_id': [],
+            'passage_num': []
         }
         max_passage_num = max(
             [len(sample['passages']) for sample in batch_data['raw_data']])
-        #max_passage_num = min(self.max_p_num, max_passage_num)
-        max_passage_num = self.max_p_num
+        max_passage_num = min(self.max_p_num, max_passage_num)
         for sidx, sample in enumerate(batch_data['raw_data']):
+            count = 0
             for pidx in range(max_passage_num):
                 if pidx < len(sample['passages']):
+                    count += 1
                     batch_data['question_token_ids'].append(sample[
-                        'question_token_ids'])
+                        'question_token_ids'][0:self.max_q_len])
                     batch_data['question_length'].append(
-                        len(sample['question_token_ids']))
+                        min(len(sample['question_token_ids']), self.max_q_len))
                     passage_token_ids = sample['passages'][pidx][
-                        'passage_token_ids']
+                        'passage_token_ids'][0:self.max_p_len]
                     batch_data['passage_token_ids'].append(passage_token_ids)
                     batch_data['passage_length'].append(
                         min(len(passage_token_ids), self.max_p_len))
-                else:
-                    batch_data['question_token_ids'].append([])
-                    batch_data['question_length'].append(0)
-                    batch_data['passage_token_ids'].append([])
-                    batch_data['passage_length'].append(0)
-        batch_data, padded_p_len, padded_q_len = self._dynamic_padding(
-            batch_data, pad_id)
-        for sample in batch_data['raw_data']:
+            # record the start passage index of the current doc
+            passage_idx_offset = sum(batch_data['passage_num'])
+            batch_data['passage_num'].append(count)
+            gold_passage_offset = 0
             if 'answer_passages' in sample and len(sample['answer_passages']):
-                gold_passage_offset = padded_p_len * sample['answer_passages'][
-                    0]
-                batch_data['start_id'].append(gold_passage_offset + sample[
-                    'answer_spans'][0][0])
-                batch_data['end_id'].append(gold_passage_offset + sample[
-                    'answer_spans'][0][1])
+                for i in range(sample['answer_passages'][0]):
+                    gold_passage_offset += len(batch_data['passage_token_ids'][
+                        passage_idx_offset + i])
+                start_id = min(sample['answer_spans'][0][0], self.max_p_len)
+                end_id = min(sample['answer_spans'][0][1], self.max_p_len)
+                batch_data['start_id'].append(gold_passage_offset + start_id)
+                batch_data['end_id'].append(gold_passage_offset + end_id)
             else:
                 # fake span for some samples, only valid for testing
                 batch_data['start_id'].append(0)
                 batch_data['end_id'].append(0)
         return batch_data

-    def _dynamic_padding(self, batch_data, pad_id):
-        """
-        Dynamically pads the batch_data with pad_id
-        """
-        pad_p_len = min(self.max_p_len, max(batch_data['passage_length']))
-        pad_q_len = min(self.max_q_len, max(batch_data['question_length']))
-        batch_data['passage_token_ids'] = [
-            (ids + [pad_id] * (pad_p_len - len(ids)))[:pad_p_len]
-            for ids in batch_data['passage_token_ids']
-        ]
-        batch_data['question_token_ids'] = [
-            (ids + [pad_id] * (pad_q_len - len(ids)))[:pad_q_len]
-            for ids in batch_data['question_token_ids']
-        ]
-        return batch_data, pad_p_len, pad_q_len
-
     def word_iter(self, set_name=None):
         """
         Iterates over all the words in the dataset
...
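Because `_dynamic_padding` is removed in this hunk, the gold answer span can no longer be shifted by a fixed `padded_p_len * answer_passages[0]`; the offset is now the sum of the actual (truncated) lengths of the passages that precede the gold one. A minimal pure-Python sketch of that arithmetic (the helper name is hypothetical, not part of the patch):

def gold_span_offsets(passage_lengths, gold_passage_idx, answer_span):
    """passage_lengths: truncated lengths of one sample's passages."""
    offset = sum(passage_lengths[:gold_passage_idx])  # tokens before the gold passage
    start, end = answer_span
    return offset + start, offset + end

# Three passages of lengths 4, 7, 5; span (2, 3) inside passage 1 lands at
# positions (6, 7) of the concatenated token sequence.
assert gold_span_offsets([4, 7, 5], 1, (2, 3)) == (6, 7)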
@@ -68,16 +68,23 @@ def bi_lstm_encoder(input_seq, gate_size, para_name, args):
     return encoder_out

-def encoder(input_name, para_name, shape, hidden_size, args):
+def get_data(input_name, lod_level, args):
     input_ids = layers.data(
-        name=input_name, shape=[1], dtype='int64', lod_level=1)
+        name=input_name, shape=[1], dtype='int64', lod_level=lod_level)
+    return input_ids
+
+
+def embedding(input_ids, shape, args):
     input_embedding = layers.embedding(
         input=input_ids,
         size=shape,
         dtype='float32',
         is_sparse=True,
         param_attr=fluid.ParamAttr(name='embedding_para'))
+    return input_embedding
+
+
+def encoder(input_embedding, para_name, hidden_size, args):
     encoder_out = bi_lstm_encoder(
         input_seq=input_embedding,
         gate_size=hidden_size,
@@ -259,40 +266,41 @@ def fusion(g, args):

 def rc_model(hidden_size, vocab, args):
     emb_shape = [vocab.size(), vocab.embed_dim]
+    start_labels = layers.data(
+        name="start_lables", shape=[1], dtype='float32', lod_level=1)
+    end_labels = layers.data(
+        name="end_lables", shape=[1], dtype='float32', lod_level=1)
+
     # stage 1:encode
-    p_ids_names = []
-    q_ids_names = []
-    ms = []
-    gs = []
-    qs = []
-    for i in range(args.doc_num):
-        p_ids_name = "pids_%d" % i
-        p_ids_names.append(p_ids_name)
-        p_enc_i = encoder(p_ids_name, 'p_enc', emb_shape, hidden_size, args)
-
-        q_ids_name = "qids_%d" % i
-        q_ids_names.append(q_ids_name)
-        q_enc_i = encoder(q_ids_name, 'q_enc', emb_shape, hidden_size, args)
+    q_id0 = get_data('q_id0', 1, args)
+    q_ids = get_data('q_ids', 2, args)
+    p_ids_name = 'p_ids'
+    p_ids = get_data('p_ids', 2, args)
+    p_embs = embedding(p_ids, emb_shape, args)
+    q_embs = embedding(q_ids, emb_shape, args)
+
+    drnn = layers.DynamicRNN()
+    with drnn.block():
+        p_emb = drnn.step_input(p_embs)
+        q_emb = drnn.step_input(q_embs)
+
+        p_enc = encoder(p_emb, 'p_enc', hidden_size, args)
+        q_enc = encoder(q_emb, 'q_enc', hidden_size, args)
         # stage 2:match
-        g_i = attn_flow(q_enc_i, p_enc_i, p_ids_name, args)
+        g_i = attn_flow(q_enc, p_enc, p_ids_name, args)
         # stage 3:fusion
         m_i = fusion(g_i, args)
-        ms.append(m_i)
-        gs.append(g_i)
-        qs.append(q_enc_i)
-    m = layers.sequence_concat(input=ms)
-    g = layers.sequence_concat(input=gs)
-    q_vec = layers.sequence_concat(input=qs)
+        drnn.output(m_i, q_enc)
+
+    ms, q_encs = drnn()
+    p_vec = layers.lod_reset(x=ms, y=start_labels)
+    q_vec = layers.lod_reset(x=q_encs, y=q_id0)
     # stage 4:decode
     start_probs, end_probs = point_network_decoder(
-        p_vec=m, q_vec=q_vec, hidden_size=hidden_size, args=args)
-
-    start_labels = layers.data(
-        name="start_lables", shape=[1], dtype='float32', lod_level=1)
-    end_labels = layers.data(
-        name="end_lables", shape=[1], dtype='float32', lod_level=1)
+        p_vec=p_vec, q_vec=q_vec, hidden_size=hidden_size, args=args)

     cost0 = layers.sequence_pool(
         layers.cross_entropy(
...
@@ -308,5 +316,5 @@ def rc_model(hidden_size, vocab, args):
     cost = cost0 + cost1
     cost.persistable = True

-    feeding_list = q_ids_names + ["start_lables", "end_lables"] + p_ids_names
-    return cost, start_probs, end_probs, feeding_list
+    feeding_list = ["q_ids", "start_lables", "end_lables", "p_ids", "q_id0"]
+    return cost, start_probs, end_probs, ms, feeding_list
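The structural change here is worth spelling out: instead of unrolling a Python loop over a fixed args.doc_num and concatenating with sequence_concat, the new graph feeds 2-level LoD tensors (doc → passages → tokens) and lets DynamicRNN step over the outer level, so the passage count per sample can vary. A condensed sketch of that pattern, assuming the PaddlePaddle 1.x fluid API this repo targets (sizes and the fc stand-in for match/fusion are illustrative):

import paddle.fluid as fluid
import paddle.fluid.layers as layers

# 2-level LoD input: outer level = passages of a doc, inner level = tokens
p_ids = layers.data(name='p_ids', shape=[1], dtype='int64', lod_level=2)
p_embs = layers.embedding(
    input=p_ids, size=[30000, 300], dtype='float32')  # toy vocab/emb sizes

drnn = layers.DynamicRNN()
with drnn.block():
    p_emb = drnn.step_input(p_embs)         # one passage sequence per step
    m_i = layers.fc(input=p_emb, size=150)  # stand-in for encode/match/fusion
    drnn.output(m_i)

ms = drnn()  # per-passage outputs re-stacked; lod_reset then restores
             # per-document boundaries for the decoder, as in the hunk above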
@@ -46,22 +46,32 @@ from vocab import Vocab

 def prepare_batch_input(insts, args):
-    doc_num = args.doc_num
     batch_size = len(insts['raw_data'])
+    inst_num = len(insts['passage_num'])
+    if batch_size != inst_num:
+        print("data error %d, %d" % (batch_size, inst_num))
+        return None
     new_insts = []

+    passage_idx = 0
     for i in range(batch_size):
+        p_len = 0
         p_id = []
-        q_id = []
         p_ids = []
         q_ids = []
-        p_len = 0
-        for j in range(i * doc_num, (i + 1) * doc_num):
-            p_ids.append(insts['passage_token_ids'][j])
-            p_id = p_id + insts['passage_token_ids'][j]
-            q_ids.append(insts['question_token_ids'][j])
-            q_id = q_id + insts['question_token_ids'][j]
+        q_id = []
+        p_id_r = []
+        p_ids_r = []
+        q_ids_r = []
+        q_id_r = []
+
+        for j in range(insts['passage_num'][i]):
+            p_ids.append(insts['passage_token_ids'][passage_idx + j])
+            p_id = p_id + insts['passage_token_ids'][passage_idx + j]
+            q_ids.append(insts['question_token_ids'][passage_idx + j])
+            q_id = q_id + insts['question_token_ids'][passage_idx + j]
+        passage_idx += insts['passage_num'][i]
         p_len = len(p_id)

         def _get_label(idx, ref_len):
@@ -72,11 +82,46 @@ def prepare_batch_input(insts, args):
         start_label = _get_label(insts['start_id'][i], p_len)
         end_label = _get_label(insts['end_id'][i], p_len)

-        new_inst = q_ids + [start_label, end_label] + p_ids
+        new_inst = [q_ids, start_label, end_label, p_ids, q_id]
         new_insts.append(new_inst)
     return new_insts


+def batch_reader(batch_list, args):
+    res = []
+    for batch in batch_list:
+        res.append(prepare_batch_input(batch, args))
+    return res
+
+
+def read_multiple(reader, count, clip_last=True):
+    """
+    Stack data from reader for multi-devices.
+    """
+
+    def __impl__():
+        res = []
+        for item in reader():
+            res.append(item)
+            if len(res) == count:
+                yield res
+                res = []
+        if len(res) == count:
+            yield res
+        elif not clip_last:
+            data = []
+            for item in res:
+                data += item
+            if len(data) > count:
+                inst_num_per_part = len(data) // count
+                yield [
+                    data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
+                    for i in range(count)
+                ]

+    return __impl__
+
+
 def LodTensor_Array(lod_tensor):
     lod = lod_tensor.lod()
     array = np.array(lod_tensor)
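A toy usage of the read_multiple helper added above (values illustrative): it regroups a stream of mini-batches into lists of count batches, one list element per device, clipping any incomplete tail group by default.

def toy_reader():
    for i in range(5):
        yield [i]          # stand-in for one mini-batch

stacked = read_multiple(toy_reader, count=2)
print(list(stacked()))     # [[[0], [1]], [[2], [3]]] -- the fifth batch is clipped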
@@ -103,7 +148,7 @@ def print_para(train_prog, train_exe, logger, args):
     logger.info("total param num: {0}".format(num_sum))


-def find_best_answer_for_passage(start_probs, end_probs, passage_len, args):
+def find_best_answer_for_passage(start_probs, end_probs, passage_len):
     """
     Finds the best answer with the maximum start_prob * end_prob from a single passage
     """
...
@@ -125,7 +170,7 @@ def find_best_answer_for_passage(start_probs, end_probs, passage_len, args):
     return (best_start, best_end), max_prob


-def find_best_answer(sample, start_prob, end_prob, padded_p_len, args):
+def find_best_answer_for_inst(sample, start_prob, end_prob, inst_lod):
     """
     Finds the best answer for a sample given start_prob and end_prob for each position.
     This will call find_best_answer_for_passage because there are multiple passages in a sample
@@ -134,11 +179,16 @@ def find_best_answer(sample, start_prob, end_prob, padded_p_len, args):
     for p_idx, passage in enumerate(sample['passages']):
         if p_idx >= args.max_p_num:
             continue
-        passage_len = min(args.max_p_len, len(passage['passage_tokens']))
+        if len(start_prob) != len(end_prob):
+            logger.info('error: {}'.format(sample['question']))
+            continue
+        passage_start = inst_lod[p_idx] - inst_lod[0]
+        passage_end = inst_lod[p_idx + 1] - inst_lod[0]
+        passage_len = passage_end - passage_start
         answer_span, score = find_best_answer_for_passage(
-            start_prob[p_idx * padded_p_len:(p_idx + 1) * padded_p_len],
-            end_prob[p_idx * padded_p_len:(p_idx + 1) * padded_p_len],
-            passage_len, args)
+            start_prob[passage_start:passage_end],
+            end_prob[passage_start:passage_end], passage_len)
         if score > best_score:
             best_score = score
             best_p_idx = p_idx
@@ -148,11 +198,11 @@ def find_best_answer(sample, start_prob, end_prob, padded_p_len, args):
     else:
         best_answer = ''.join(sample['passages'][best_p_idx]['passage_tokens'][
            best_span[0]:best_span[1] + 1])
-    return best_answer
+    return best_answer, best_span
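With padding gone, the probability vectors are flat over one instance's concatenated passages, and inst_lod (a slice of the match output's level-1 LoD) carries the cumulative token offsets that locate each passage. A small self-contained illustration of the slicing above (values hypothetical):

inst_lod = [10, 14, 21, 26]   # cumulative offsets: 3 passages of lengths 4, 7, 5
start_prob = list(range(inst_lod[-1] - inst_lod[0]))  # 16 flat positions

for p_idx in range(len(inst_lod) - 1):
    passage_start = inst_lod[p_idx] - inst_lod[0]
    passage_end = inst_lod[p_idx + 1] - inst_lod[0]
    print(p_idx, start_prob[passage_start:passage_end])
# prints slices of lengths 4, 7 and 5 -- one per passage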
-def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
-               vocab, brc_data, logger, args):
+def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
+               place, dev_count, vocab, brc_data, logger, args):
     """
     """
@@ -165,6 +215,8 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
     # Use test set as validation each pass
     total_loss = 0.0
     count = 0
+    n_batch_cnt = 0
+    n_batch_loss = 0.0
     pred_answers, ref_answers = [], []
     val_feed_list = [
         inference_program.global_block().var(var_name)
@@ -172,55 +224,80 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
     ]
     val_feeder = fluid.DataFeeder(val_feed_list, place)
     pad_id = vocab.get_id(vocab.pad_token)
-    dev_batches = brc_data.gen_mini_batches(
-        'dev', args.batch_size, pad_id, shuffle=False)
-    for batch_id, batch in enumerate(dev_batches, 1):
-        feed_data = prepare_batch_input(batch, args)
+    dev_reader = lambda: brc_data.gen_mini_batches(
+        'dev', args.batch_size, pad_id, shuffle=False)
+    dev_reader = read_multiple(dev_reader, dev_count)
+    for batch_id, batch_list in enumerate(dev_reader(), 1):
+        feed_data = batch_reader(batch_list, args)
         val_fetch_outs = parallel_executor.run(
-            feed=val_feeder.feed(feed_data),
-            fetch_list=[avg_cost.name, s_probs.name, e_probs.name],
+            feed=list(val_feeder.feed_parallel(feed_data, dev_count)),
+            fetch_list=[avg_cost.name, s_probs.name, e_probs.name, match.name],
             return_numpy=False)
-        total_loss += np.array(val_fetch_outs[0])[0]
-        start_probs = LodTensor_Array(val_fetch_outs[1])
-        end_probs = LodTensor_Array(val_fetch_outs[2])
-        count += len(batch['raw_data'])
-        padded_p_len = len(batch['passage_token_ids'][0])
-        for sample, start_prob, end_prob in zip(batch['raw_data'], start_probs,
-                                                end_probs):
-            best_answer = find_best_answer(sample, start_prob, end_prob,
-                                           padded_p_len, args)
-            pred_answers.append({
-                'question_id': sample['question_id'],
-                'question_type': sample['question_type'],
-                'answers': [best_answer],
-                'entity_answers': [[]],
-                'yesno_answers': []
-            })
-            if 'answers' in sample:
-                ref_answers.append({
-                    'question_id': sample['question_id'],
-                    'question_type': sample['question_type'],
-                    'answers': sample['answers'],
-                    'entity_answers': [[]],
-                    'yesno_answers': []
-                })
-    if args.result_dir is not None and args.result_name is not None:
+        total_loss += np.array(val_fetch_outs[0]).sum()
+        start_probs_m = LodTensor_Array(val_fetch_outs[1])
+        end_probs_m = LodTensor_Array(val_fetch_outs[2])
+        match_lod = val_fetch_outs[3].lod()
+        count += len(np.array(val_fetch_outs[0]))
+
+        n_batch_cnt += len(np.array(val_fetch_outs[0]))
+        n_batch_loss += np.array(val_fetch_outs[0]).sum()
+        log_every_n_batch = args.log_interval
+        if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
+            logger.info('Average dev loss from batch {} to {} is {}'.format(
+                batch_id - log_every_n_batch + 1, batch_id,
+                "%.10f" % (n_batch_loss / n_batch_cnt)))
+            n_batch_loss = 0.0
+            n_batch_cnt = 0
+
+        for idx, batch in enumerate(batch_list):
+            # one batch
+            batch_size = len(batch['raw_data'])
+            batch_range = match_lod[0][idx * batch_size:(idx + 1) * batch_size
+                                       + 1]
+            batch_lod = [[batch_range[x], batch_range[x + 1]]
+                         for x in range(len(batch_range[:-1]))]
+            start_prob_batch = start_probs_m[idx * batch_size:(idx + 1) *
                                             batch_size]
+            end_prob_batch = end_probs_m[idx * batch_size:(idx + 1) *
                                         batch_size]
+            for sample, start_prob_inst, end_prob_inst, inst_range in zip(
+                    batch['raw_data'], start_prob_batch, end_prob_batch,
+                    batch_lod):
+                # one instance
+                inst_lod = match_lod[1][inst_range[0]:inst_range[1] + 1]
+                best_answer, best_span = find_best_answer_for_inst(
+                    sample, start_prob_inst, end_prob_inst, inst_lod)
+                pred = {
+                    'question_id': sample['question_id'],
+                    'question_type': sample['question_type'],
+                    'answers': [best_answer],
+                    'entity_answers': [[]],
+                    'yesno_answers': [best_span]
+                }
+                pred_answers.append(pred)
+                if 'answers' in sample:
+                    ref = {
+                        'question_id': sample['question_id'],
+                        'question_type': sample['question_type'],
+                        'answers': sample['answers'],
+                        'entity_answers': [[]],
+                        'yesno_answers': []
+                    }
+                    ref_answers.append(ref)
+
+    result_dir = args.result_dir
+    result_prefix = args.result_name
+    if result_dir is not None and result_prefix is not None:
         if not os.path.exists(args.result_dir):
             os.makedirs(args.result_dir)
-        result_file = os.path.join(args.result_dir, args.result_name + '.json')
+        result_file = os.path.join(result_dir, result_prefix + '.json')
         with open(result_file, 'w') as fout:
             for pred_answer in pred_answers:
                 fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
-        logger.info('Saving {} results to {}'.format(args.result_name,
+        logger.info('Saving {} results to {}'.format(result_prefix,
                                                      result_file))

     ave_loss = 1.0 * total_loss / count
     # compute the bleu and rouge scores if reference answers is provided
     if len(ref_answers) > 0:
         pred_dict, ref_dict = {}, {}
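The bookkeeping above decodes a two-level LoD: level 0 of match_lod marks instance boundaries within each device's batch, level 1 marks cumulative passage token offsets, so slicing level 0 per instance and level 1 per passage recovers each sample's spans. A self-contained sketch with hypothetical values:

match_lod = [[0, 3, 5],                  # level 0: instance 0 owns passages 0..3, instance 1 owns 3..5
             [0, 4, 11, 16, 20, 26]]     # level 1: cumulative token offsets per passage
for inst_idx in range(len(match_lod[0]) - 1):
    lo, hi = match_lod[0][inst_idx], match_lod[0][inst_idx + 1]
    inst_lod = match_lod[1][lo:hi + 1]
    lengths = [inst_lod[k + 1] - inst_lod[k] for k in range(len(inst_lod) - 1)]
    print('instance', inst_idx, 'passage lengths:', lengths)
# instance 0 passage lengths: [4, 7, 5]
# instance 1 passage lengths: [4, 6]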
@@ -250,6 +327,13 @@ def train(logger, args):
     brc_data.convert_to_ids(vocab)

     logger.info('Initialize the model...')
+    if not args.use_gpu:
+        place = fluid.CPUPlace()
+        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+    else:
+        place = fluid.CUDAPlace(0)
+        dev_count = fluid.core.get_cuda_device_count()
+
     # build model
     main_program = fluid.Program()
     startup_prog = fluid.Program()
@@ -257,7 +341,7 @@ def train(logger, args):
     startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
-            avg_cost, s_probs, e_probs, feed_order = rc_model.rc_model(
+            avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
                 args.hidden_size, vocab, args)
             # clone from default main program and use it as the validation program
             inference_program = main_program.clone(for_test=True)
@@ -314,20 +398,21 @@ def train(logger, args):
     for pass_id in range(1, args.pass_num + 1):
         pass_start_time = time.time()
         pad_id = vocab.get_id(vocab.pad_token)
-        train_batches = brc_data.gen_mini_batches(
-            'train', args.batch_size, pad_id, shuffle=True)
+        train_reader = lambda: brc_data.gen_mini_batches(
+            'train', args.batch_size, pad_id, shuffle=False)
+        train_reader = read_multiple(train_reader, dev_count)
         log_every_n_batch, n_batch_loss = args.log_interval, 0
         total_num, total_loss = 0, 0
-        for batch_id, batch in enumerate(train_batches, 1):
-            input_data_dict = prepare_batch_input(batch, args)
+        for batch_id, batch_list in enumerate(train_reader(), 1):
+            feed_data = batch_reader(batch_list, args)
             fetch_outs = parallel_executor.run(
-                feed=feeder.feed(input_data_dict),
+                feed=list(feeder.feed_parallel(feed_data, dev_count)),
                 fetch_list=[avg_cost.name],
                 return_numpy=False)
-            cost_train = np.array(fetch_outs[0])[0]
-            total_num += len(batch['raw_data'])
+            cost_train = np.array(fetch_outs[0]).mean()
+            total_num += args.batch_size * dev_count
             n_batch_loss += cost_train
-            total_loss += cost_train * len(batch['raw_data'])
+            total_loss += cost_train * args.batch_size * dev_count

             if log_every_n_batch > 0 and batch_id % log_every_n_batch == 0:
                 print_para(main_program, parallel_executor, logger,
                            args)
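One consequence of feed_parallel worth noting: fetch_outs[0] now holds one loss value per device, so the scalar batch loss is their mean, and sample counts are approximated by args.batch_size * dev_count. A tiny numeric sketch (values hypothetical):

import numpy as np

per_device_loss = np.array([2.0, 2.4, 1.6, 2.0])  # hypothetical dev_count = 4
cost_train = per_device_loss.mean()               # 2.0, the logged batch loss
total_loss_contrib = cost_train * 8 * 4           # batch_size * dev_count
print(cost_train, total_loss_contrib)             # 2.0 64.0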
@@ -337,19 +422,23 @@ def train(logger, args):
                     "%.10f" % (n_batch_loss / log_every_n_batch)))
                 n_batch_loss = 0
             if args.dev_interval > 0 and batch_id % args.dev_interval == 0:
-                eval_loss, bleu_rouge = validation(
-                    inference_program, avg_cost, s_probs, e_probs,
-                    feed_order, place, vocab, brc_data, logger, args)
-                logger.info('Dev eval loss {}'.format(eval_loss))
-                logger.info('Dev eval result: {}'.format(bleu_rouge))
+                if brc_data.dev_set is not None:
+                    eval_loss, bleu_rouge = validation(
+                        inference_program, avg_cost, s_probs, e_probs,
+                        match, feed_order, place, dev_count, vocab,
+                        brc_data, logger, args)
+                    logger.info('Dev eval loss {}'.format(eval_loss))
+                    logger.info('Dev eval result: {}'.format(bleu_rouge))
         pass_end_time = time.time()

         logger.info('Evaluating the model after epoch {}'.format(pass_id))
         if brc_data.dev_set is not None:
             eval_loss, bleu_rouge = validation(
-                inference_program, avg_cost, s_probs, e_probs,
-                feed_order, place, vocab, brc_data, logger, args)
+                inference_program, avg_cost, s_probs, e_probs, match,
+                feed_order, place, dev_count, vocab, brc_data, logger,
+                args)
             logger.info('Dev eval loss {}'.format(eval_loss))
             logger.info('Dev eval result: {}'.format(bleu_rouge))
         else:
@@ -389,10 +478,17 @@ def evaluate(logger, args):
     startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
-            avg_cost, s_probs, e_probs, feed_order = rc_model.rc_model(
+            avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
                 args.hidden_size, vocab, args)
             # initialize parameters
-            place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
+            if not args.use_gpu:
+                place = fluid.CPUPlace()
+                dev_count = int(
+                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+            else:
+                place = fluid.CUDAPlace(0)
+                dev_count = fluid.core.get_cuda_device_count()
+
             exe = Executor(place)
             if args.load_dir:
                 logger.info('load from {}'.format(args.load_dir))
@@ -402,17 +498,10 @@ def evaluate(logger, args):
                 logger.error('No model file to load ...')
                 return

-            # prepare data
-            feed_list = [
-                main_program.global_block().var(var_name)
-                for var_name in feed_order
-            ]
-            feeder = fluid.DataFeeder(feed_list, place)
-
             inference_program = main_program.clone(for_test=True)
             eval_loss, bleu_rouge = validation(
-                inference_program, avg_cost, s_probs, e_probs, feed_order,
-                place, vocab, brc_data, logger, args)
+                inference_program, avg_cost, s_probs, e_probs, match,
+                feed_order, place, dev_count, vocab, brc_data, logger, args)
             logger.info('Dev eval loss {}'.format(eval_loss))
             logger.info('Dev eval result: {}'.format(bleu_rouge))
             logger.info('Predicted answers are saved to {}'.format(
@@ -438,10 +527,17 @@ def predict(logger, args):
     startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
-            avg_cost, s_probs, e_probs, feed_order = rc_model.rc_model(
+            avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
                 args.hidden_size, vocab, args)
             # initialize parameters
-            place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
+            if not args.use_gpu:
+                place = fluid.CPUPlace()
+                dev_count = int(
+                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+            else:
+                place = fluid.CUDAPlace(0)
+                dev_count = fluid.core.get_cuda_device_count()
+
             exe = Executor(place)
             if args.load_dir:
                 logger.info('load from {}'.format(args.load_dir))
@@ -451,17 +547,10 @@ def predict(logger, args):
                 logger.error('No model file to load ...')
                 return

-            # prepare data
-            feed_list = [
-                main_program.global_block().var(var_name)
-                for var_name in feed_order
-            ]
-            feeder = fluid.DataFeeder(feed_list, place)
-
             inference_program = main_program.clone(for_test=True)
             eval_loss, bleu_rouge = validation(
-                inference_program, avg_cost, s_probs, e_probs, feed_order,
-                place, vocab, brc_data, logger, args)
+                inference_program, avg_cost, s_probs, e_probs, match,
+                feed_order, place, dev_count, vocab, brc_data, logger, args)


 def prepare(logger, args):
...
-export CUDA_VISIBLE_DEVICES=1
+export CUDA_VISIBLE_DEVICES=0
 python run.py \
 --trainset 'data/preprocessed/trainset/search.train.json' \
            'data/preprocessed/trainset/zhidao.train.json' \
@@ -11,11 +11,12 @@ python run.py \
 --save_dir ./models \
 --pass_num 10 \
 --learning_rate 0.001 \
---batch_size 8 \
+--batch_size 32 \
 --embed_size 300 \
 --hidden_size 150 \
 --max_p_num 5 \
 --max_p_len 500 \
 --max_q_len 60 \
 --max_a_len 200 \
+--weight_decay 0.0 \
 --drop_rate 0.2 $@\