Commit 87e8727d authored by xuezhong

fix batch offset bug

Parent 3eba53b7
@@ -317,5 +317,4 @@ def rc_model(hidden_size, vocab, args):
     cost.persistable = True
 
     feeding_list = ["q_ids", "start_lables", "end_lables", "p_ids", "q_id0"]
-    layers.Print(ms, message='ms', summarize=3)
     return cost, start_probs, end_probs, ms, feeding_list
@@ -236,11 +236,7 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
         total_loss += np.array(val_fetch_outs[0]).sum()
         start_probs_m = LodTensor_Array(val_fetch_outs[1])
         end_probs_m = LodTensor_Array(val_fetch_outs[2])
-        for data in feed_data:
-            data_len = [[len(y) for y in x[3]] for x in data]
-            logger.info(str(data_len))
         match_lod = val_fetch_outs[3].lod()
-        logger.info(str(match_lod))
         count += len(np.array(val_fetch_outs[0]))
 
         n_batch_cnt += len(np.array(val_fetch_outs[0]))
@@ -252,18 +248,18 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
                     n_batch_loss / n_batch_cnt)))
                 n_batch_loss = 0.0
                 n_batch_cnt = 0
+        batch_offset = 0
         for idx, batch in enumerate(batch_list):
             #one batch
             batch_size = len(batch['raw_data'])
-            batch_range = match_lod[0][idx * batch_size:(idx + 1) * batch_size +
-                                       1]
+            batch_range = match_lod[0][batch_offset:batch_offset + batch_size +
+                                       1]
             batch_lod = [[batch_range[x], batch_range[x + 1]]
                          for x in range(len(batch_range[:-1]))]
-            start_prob_batch = start_probs_m[idx * batch_size:(idx + 1) *
-                                             batch_size]
-            end_prob_batch = end_probs_m[idx * batch_size:(idx + 1) *
-                                         batch_size]
+            start_prob_batch = start_probs_m[batch_offset:batch_offset +
+                                             batch_size + 1]
+            end_prob_batch = end_probs_m[batch_offset:batch_offset + batch_size
+                                         + 1]
             for sample, start_prob_inst, end_prob_inst, inst_range in zip(
                     batch['raw_data'], start_prob_batch, end_prob_batch,
                     batch_lod):
@@ -288,6 +284,7 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
                     'yesno_answers': []
                 }
                 ref_answers.append(ref)
+            batch_offset = batch_offset + batch_size
 
     result_dir = args.result_dir
     result_prefix = args.result_name
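Note on the fix above: the old slices addressed start_probs_m, end_probs_m, and match_lod[0] with idx * batch_size, which silently assumes every batch in batch_list has the same size. As soon as one batch is smaller (typically the last, partially filled one), every later slice is shifted and samples get paired with the wrong probabilities. The new code keeps a running batch_offset that accumulates the sizes actually consumed, in the same spirit as the cumulative boundaries stored in the LoD. A minimal, self-contained sketch (toy data and names, not the repo's structures) of the difference:

# Toy illustration of the indexing bug this commit fixes.
probs = list(range(10))        # stand-in for per-sample results, e.g. start_probs_m
toy_batch_sizes = [4, 4, 2]    # the last batch is smaller

# Old scheme: stride by the *current* batch's size.
old = [probs[idx * bs:(idx + 1) * bs]
       for idx, bs in enumerate(toy_batch_sizes)]

# New scheme: accumulate the sizes already consumed.
new, offset = [], 0
for bs in toy_batch_sizes:
    new.append(probs[offset:offset + bs])
    offset += bs

print(old)  # [[0, 1, 2, 3], [4, 5, 6, 7], [4, 5]] -- tail batch re-reads 4 and 5
print(new)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -- correct

With equal-sized batches the two schemes coincide, which is presumably why the bug only surfaced on the smaller tail batch of the validation set.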
@@ -341,8 +338,9 @@ def train(logger, args):
     # build model
     main_program = fluid.Program()
     startup_prog = fluid.Program()
-    main_program.random_seed = args.random_seed
-    startup_prog.random_seed = args.random_seed
+    if args.enable_ce:
+        main_program.random_seed = args.random_seed
+        startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
             avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
@@ -402,7 +400,10 @@ def train(logger, args):
         for pass_id in range(1, args.pass_num + 1):
             pass_start_time = time.time()
             pad_id = vocab.get_id(vocab.pad_token)
-            train_reader = lambda:brc_data.gen_mini_batches('train', args.batch_size, pad_id, shuffle=False)
+            if args.enable_ce:
+                train_reader = lambda:brc_data.gen_mini_batches('train', args.batch_size, pad_id, shuffle=False)
+            else:
+                train_reader = lambda:brc_data.gen_mini_batches('train', args.batch_size, pad_id, shuffle=True)
             train_reader = read_multiple(train_reader, dev_count)
             log_every_n_batch, n_batch_loss = args.log_interval, 0
             total_num, total_loss = 0, 0
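The second part of the commit gates reproducibility behind args.enable_ce: program-level random seeds are set and batch shuffling is disabled only for continuous-evaluation (CE) runs, so ordinary training keeps its shuffled reader. A toy sketch of the shuffle gating (gen_mini_batches here is a stand-in, not the repo's reader):

# Toy stand-in for the reader gating added above: under CE the batch order is
# fixed so metrics are comparable across runs; normal training still shuffles.
import random

def gen_mini_batches(samples, batch_size, shuffle=True):
    """Yield lists of samples, optionally in shuffled order."""
    indices = list(range(len(samples)))
    if shuffle:
        random.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [samples[i] for i in indices[start:start + batch_size]]

data = list(range(10))
enable_ce = True
# Mirrors the lambda pattern in the patch: the reader is rebuilt every pass.
train_reader = lambda: gen_mini_batches(data, 4, shuffle=not enable_ce)
print(list(train_reader()))  # deterministic order when enable_ce is True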
@@ -488,8 +489,6 @@ def evaluate(logger, args):
     # build model
     main_program = fluid.Program()
     startup_prog = fluid.Program()
-    main_program.random_seed = args.random_seed
-    startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
             avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
@@ -537,8 +536,6 @@ def predict(logger, args):
     # build model
     main_program = fluid.Program()
     startup_prog = fluid.Program()
-    main_program.random_seed = args.random_seed
-    startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
             avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
@@ -606,8 +603,9 @@ def prepare(logger, args):
 if __name__ == '__main__':
     args = parse_args()
 
-    random.seed(args.random_seed)
-    np.random.seed(args.random_seed)
+    if args.enable_ce:
+        random.seed(args.random_seed)
+        np.random.seed(args.random_seed)
 
     logger = logging.getLogger("brc")
     logger.setLevel(logging.INFO)
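The same gate appears at process start: Python's and numpy's global RNGs are seeded only when CE mode is requested. A standalone sketch of that pattern; the flag declaration below is an assumption for illustration, since parse_args is defined elsewhere and is not part of this diff:

# Standalone sketch of the CE seeding gate (the argparse flag is assumed for
# illustration; the real parse_args lives elsewhere in the repo).
import argparse
import random
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('--enable_ce', action='store_true',
                    help='run deterministically for continuous evaluation')
parser.add_argument('--random_seed', type=int, default=123)
args = parser.parse_args()

if args.enable_ce:
    random.seed(args.random_seed)       # Python-level RNG
    np.random.seed(args.random_seed)    # numpy RNG used by the data pipeline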