Commit 87e8727d authored by xuezhong

fix batch offset bug

In validation(), the per-batch slices of match_lod, start_probs_m and end_probs_m were computed from idx * batch_size, which drifts once a batch (typically the last one of a pass) is smaller than the others; they are now taken from an accumulated batch_offset. The commit also drops leftover debug Print/logger calls, removes the unconditional seeding from evaluate() and predict(), and makes the fixed random seeds and the unshuffled training reader conditional on args.enable_ce.

Parent 3eba53b7
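The bug this commit targets: validation slices the flattened per-instance model outputs with `idx * batch_size`, which assumes every batch has the same size; once the last batch of a pass is smaller, those slices drift onto the wrong instances. The fix accumulates an explicit `batch_offset` instead. A minimal, self-contained sketch of the two indexing strategies, using made-up batches rather than the real reader output:

```python
# Illustration only: toy batches standing in for the validation batches in run.py.
batch_list = [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31]]  # last batch is smaller
outputs = [x for batch in batch_list for x in batch]         # flattened per-instance outputs

# Buggy indexing: assumes every earlier batch was as long as the current one.
for idx, batch in enumerate(batch_list):
    batch_size = len(batch)
    chunk = outputs[idx * batch_size:(idx + 1) * batch_size]
    # For the 2-element batch (idx=2) this reads outputs[4:6] == [20, 21],
    # i.e. instances that belong to the second batch.

# Fixed indexing: remember how many instances were already consumed.
batch_offset = 0
for idx, batch in enumerate(batch_list):
    batch_size = len(batch)
    chunk = outputs[batch_offset:batch_offset + batch_size]
    assert chunk == batch
    batch_offset += batch_size
```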
@@ -317,5 +317,4 @@ def rc_model(hidden_size, vocab, args):
     cost.persistable = True
     feeding_list = ["q_ids", "start_lables", "end_lables", "p_ids", "q_id0"]
-    layers.Print(ms, message='ms', summarize=3)
     return cost, start_probs, end_probs, ms, feeding_list
@@ -236,11 +236,7 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
         total_loss += np.array(val_fetch_outs[0]).sum()
         start_probs_m = LodTensor_Array(val_fetch_outs[1])
         end_probs_m = LodTensor_Array(val_fetch_outs[2])
-        for data in feed_data:
-            data_len = [[len(y) for y in x[3]] for x in data]
-            logger.info(str(data_len))
         match_lod = val_fetch_outs[3].lod()
-        logger.info(str(match_lod))
         count += len(np.array(val_fetch_outs[0]))
         n_batch_cnt += len(np.array(val_fetch_outs[0]))
@@ -252,18 +248,18 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
                     n_batch_loss / n_batch_cnt)))
             n_batch_loss = 0.0
             n_batch_cnt = 0
+        batch_offset = 0
         for idx, batch in enumerate(batch_list):
             #one batch
             batch_size = len(batch['raw_data'])
-            batch_range = match_lod[0][idx * batch_size:(idx + 1) * batch_size +
+            batch_range = match_lod[0][batch_offset:batch_offset + batch_size +
                                        1]
             batch_lod = [[batch_range[x], batch_range[x + 1]]
                          for x in range(len(batch_range[:-1]))]
-            start_prob_batch = start_probs_m[idx * batch_size:(idx + 1) *
-                                             batch_size]
-            end_prob_batch = end_probs_m[idx * batch_size:(idx + 1) *
-                                         batch_size]
+            start_prob_batch = start_probs_m[batch_offset:batch_offset +
+                                             batch_size + 1]
+            end_prob_batch = end_probs_m[batch_offset:batch_offset + batch_size
+                                         + 1]
             for sample, start_prob_inst, end_prob_inst, inst_range in zip(
                     batch['raw_data'], start_prob_batch, end_prob_batch,
                     batch_lod):
@@ -288,6 +284,7 @@ def validation(inference_program, avg_cost, s_probs, e_probs, match, feed_order,
                     'yesno_answers': []
                 }
                 ref_answers.append(ref)
+            batch_offset = batch_offset + batch_size
     result_dir = args.result_dir
     result_prefix = args.result_name
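A note on the `+ 1` in the `match_lod` slice above: in Paddle's offset-based LoD, level 0 is a list of cumulative offsets with one more entry than there are instances, so a batch of `batch_size` instances needs `batch_size + 1` consecutive offsets, which the loop then pairs into per-instance `[start, end)` ranges. A small sketch with a hand-written offset list (not real Paddle output), mirroring the `batch_lod` construction:

```python
# Hand-written level-0 LoD offsets for 5 instances of lengths 3, 2, 4, 1, 2:
# cumulative sums with a leading 0.
lod0 = [0, 3, 5, 9, 10, 12]

def batch_ranges(lod0, batch_offset, batch_size):
    """Mirror the batch_lod construction: take batch_size + 1 consecutive
    offsets and pair neighbouring entries into [start, end) ranges."""
    batch_range = lod0[batch_offset:batch_offset + batch_size + 1]
    return [[batch_range[x], batch_range[x + 1]]
            for x in range(len(batch_range) - 1)]

print(batch_ranges(lod0, 0, 2))  # first batch, 2 instances: [[0, 3], [3, 5]]
print(batch_ranges(lod0, 2, 3))  # next batch, 3 instances: [[5, 9], [9, 10], [10, 12]]
```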
@@ -341,8 +338,9 @@ def train(logger, args):
     # build model
     main_program = fluid.Program()
     startup_prog = fluid.Program()
-    main_program.random_seed = args.random_seed
-    startup_prog.random_seed = args.random_seed
+    if args.enable_ce:
+        main_program.random_seed = args.random_seed
+        startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
             avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
@@ -402,7 +400,10 @@ def train(logger, args):
             for pass_id in range(1, args.pass_num + 1):
                 pass_start_time = time.time()
                 pad_id = vocab.get_id(vocab.pad_token)
-                train_reader = lambda:brc_data.gen_mini_batches('train', args.batch_size, pad_id, shuffle=False)
+                if args.enable_ce:
+                    train_reader = lambda:brc_data.gen_mini_batches('train', args.batch_size, pad_id, shuffle=False)
+                else:
+                    train_reader = lambda:brc_data.gen_mini_batches('train', args.batch_size, pad_id, shuffle=True)
                 train_reader = read_multiple(train_reader, dev_count)
                 log_every_n_batch, n_batch_loss = args.log_interval, 0
                 total_num, total_loss = 0, 0
@@ -488,8 +489,6 @@ def evaluate(logger, args):
     # build model
     main_program = fluid.Program()
     startup_prog = fluid.Program()
-    main_program.random_seed = args.random_seed
-    startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
             avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
@@ -537,8 +536,6 @@ def predict(logger, args):
     # build model
     main_program = fluid.Program()
     startup_prog = fluid.Program()
-    main_program.random_seed = args.random_seed
-    startup_prog.random_seed = args.random_seed
     with fluid.program_guard(main_program, startup_prog):
         with fluid.unique_name.guard():
             avg_cost, s_probs, e_probs, match, feed_order = rc_model.rc_model(
@@ -606,8 +603,9 @@ def prepare(logger, args):
 if __name__ == '__main__':
     args = parse_args()
-    random.seed(args.random_seed)
-    np.random.seed(args.random_seed)
+    if args.enable_ce:
+        random.seed(args.random_seed)
+        np.random.seed(args.random_seed)
     logger = logging.getLogger("brc")
     logger.setLevel(logging.INFO)
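The `args.enable_ce` branches in `train()` and the `__main__` block follow a common continuous-evaluation pattern: seed the RNGs and read the data in a fixed order only when reproducible numbers are requested, and keep shuffled training as the default. A framework-agnostic sketch of the same switch (the toy reader below is illustrative, not the project's actual API):

```python
import argparse
import random

import numpy as np


def gen_mini_batches(data, batch_size, shuffle=True):
    """Toy reader standing in for brc_data.gen_mini_batches."""
    order = list(range(len(data)))
    if shuffle:
        random.shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [data[i] for i in order[start:start + batch_size]]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--enable_ce', action='store_true')
    parser.add_argument('--random_seed', type=int, default=123)
    args = parser.parse_args()

    if args.enable_ce:
        # Continuous-evaluation mode: fixed seeds, deterministic batch order.
        random.seed(args.random_seed)
        np.random.seed(args.random_seed)

    train_reader = lambda: gen_mini_batches(list(range(10)), batch_size=4,
                                            shuffle=not args.enable_ce)
    for batch in train_reader():
        print(batch)
```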