Commit 44ad16aa authored by: G gongweibao

fix

Parent 8a14a4ce
@@ -192,35 +192,26 @@ def main():
         beta1=TrainTaskConfig.beta1,
         beta2=TrainTaskConfig.beta2,
         epsilon=TrainTaskConfig.eps)
-    optimizer.minimize(avg_cost if TrainTaskConfig.use_avg_cost else sum_cost)
+    optimize_ops, params_grads = optimizer.minimize(avg_cost if TrainTaskConfig.use_avg_cost else sum_cost)
 
-    train_data = paddle.batch(
-        paddle.reader.shuffle(
-            nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
-                                     ModelHyperParams.trg_vocab_size),
-            buf_size=100000),
-        batch_size=TrainTaskConfig.batch_size)
 
     # Program to do validation.
     inference_program = fluid.default_main_program().clone()
     with fluid.program_guard(inference_program):
         inference_program = fluid.io.get_inference_program([avg_cost])
 
     val_data = paddle.batch(
         nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
                                  ModelHyperParams.trg_vocab_size),
         batch_size=TrainTaskConfig.batch_size)'''
 
     def test(exe):
         test_total_cost = 0
         test_total_token = 0
-        for batch_id, data in enumerate(val_data()):
+        for batch_id, data in enumerate(test_reader()):
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1] +
                 label_data_names, ModelHyperParams.eos_idx,
                 ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                 ModelHyperParams.d_model)
             test_sum_cost, test_token_num = exe.run(
-                test_program,
+                inference_program,
                 feed=data_input,
                 fetch_list=[sum_cost, token_num],
                 use_program_cache=True)
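
Editor's note on this hunk: `optimizer.minimize()` is now unpacked into `optimize_ops` and `params_grads`. In Fluid releases of this era, the distribute transpiler consumed exactly those two values when splitting the program between trainers and parameter servers. A minimal sketch of that wiring, assuming the early `DistributeTranspiler.transpile()` signature; the toy network and endpoint are illustrative, not from this commit:

    import paddle.fluid as fluid

    # Toy network standing in for the transformer above.
    x = fluid.layers.data(name="x", shape=[4], dtype="float32")
    y = fluid.layers.data(name="y", shape=[1], dtype="float32")
    pred = fluid.layers.fc(input=x, size=1)
    cost = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

    # minimize() returns the optimize ops and the (param, grad) pairs ...
    optimize_ops, params_grads = fluid.optimizer.Adam().minimize(cost)

    # ... which the early transpiler took as inputs when splitting the program.
    t = fluid.DistributeTranspiler()
    t.transpile(optimize_ops, params_grads, trainer_id=0,
                pservers="127.0.0.1:6174", trainers=1)
    trainer_prog = t.get_trainer_program()
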
@@ -230,50 +221,52 @@ def main():
         test_ppl = np.exp([min(test_avg_cost, 100)])
         return test_avg_cost, test_ppl
 
-    # Initialize the parameters.
-    exe.run(fluid.framework.default_startup_program())
-    for pos_enc_param_name in pos_enc_param_names:
-        pos_enc_param = fluid.global_scope().find_var(
-            pos_enc_param_name).get_tensor()
-        pos_enc_param.set(
-            position_encoding_init(ModelHyperParams.max_length + 1,
-                                   ModelHyperParams.d_model), place)
-
-    for pass_id in xrange(TrainTaskConfig.pass_num):
-        pass_start_time = time.time()
-        for batch_id, data in enumerate(train_data()):
-            if len(data) != TrainTaskConfig.batch_size:
-                continue
-            data_input = prepare_batch_input(
-                data, encoder_input_data_names + decoder_input_data_names[:-1] +
-                label_data_names, ModelHyperParams.eos_idx,
-                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
-                ModelHyperParams.d_model)
-            lr_scheduler.update_learning_rate(data_input)
-            outs = exe.run(fluid.framework.default_main_program(),
-                           feed=data_input,
-                           fetch_list=[sum_cost, avg_cost],
-                           use_program_cache=True)
-            sum_cost_val, avg_cost_val = np.array(outs[0]), np.array(outs[1])
-            print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
-                  (pass_id, batch_id, sum_cost_val, avg_cost_val,
-                   np.exp([min(avg_cost_val[0], 100)])))
-
-        # Validate and save the model for inference.
-        #val_avg_cost, val_ppl = test(exe)
-        pass_end_time = time.time()
-        time_consumed = pass_end_time - pass_start_time
-        print("pass_id = " + str(pass_id) + " time_consumed = " +
-              str(time_consumed))
-        #print("epoch: %d, val avg loss: %f, val ppl: %f, "
-        #      "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
-        fluid.io.save_inference_model(
-            os.path.join(TrainTaskConfig.model_dir,
-                         "pass_" + str(pass_id) + ".infer.model"),
-            encoder_input_data_names + decoder_input_data_names[:-1],
-            [predict], exe)
-
-    if args.local:
+    def train_loop(exe, trainer_prog):
+        # Initialize the parameters.
+        """
+        exe.run(fluid.framework.default_startup_program())
+        """
+        for pos_enc_param_name in pos_enc_param_names:
+            pos_enc_param = fluid.global_scope().find_var(
+                pos_enc_param_name).get_tensor()
+            pos_enc_param.set(
+                position_encoding_init(ModelHyperParams.max_length + 1,
+                                       ModelHyperParams.d_model), place)
+
+        for pass_id in xrange(TrainTaskConfig.pass_num):
+            pass_start_time = time.time()
+            for batch_id, data in enumerate(train_reader()):
+                if len(data) != TrainTaskConfig.batch_size:
+                    continue
+                data_input = prepare_batch_input(
+                    data, encoder_input_data_names + decoder_input_data_names[:-1] +
+                    label_data_names, ModelHyperParams.eos_idx,
+                    ModelHyperParams.eos_idx, ModelHyperParams.n_head,
+                    ModelHyperParams.d_model)
+                lr_scheduler.update_learning_rate(data_input)
+                outs = exe.run(trainer_prog,
+                               feed=data_input,
+                               fetch_list=[sum_cost, avg_cost],
+                               use_program_cache=True)
+                sum_cost_val, avg_cost_val = np.array(outs[0]), np.array(outs[1])
+                print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
+                      (pass_id, batch_id, sum_cost_val, avg_cost_val,
+                       np.exp([min(avg_cost_val[0], 100)])))
+
+            # Validate and save the model for inference.
+            #val_avg_cost, val_ppl = test(exe)
+            pass_end_time = time.time()
+            time_consumed = pass_end_time - pass_start_time
+            print("pass_id = " + str(pass_id) + " time_consumed = " +
+                  str(time_consumed))
+            #print("epoch: %d, val avg loss: %f, val ppl: %f, "
+            #      "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
+            fluid.io.save_inference_model(
+                os.path.join(TrainTaskConfig.model_dir,
+                             "pass_" + str(pass_id) + ".infer.model"),
+                encoder_input_data_names + decoder_input_data_names[:-1],
+                [predict], exe)
+
+    if args.local:
+        # Initialize the parameters.
+        exe.run(fluid.framework.default_startup_program())
+        #print("local start_up:")
@@ -288,15 +281,15 @@ def main():
         train_reader = paddle.batch(
             paddle.reader.shuffle(
-                paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
-                                           ModelHyperParams.trg_vocab_size),
+                nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
+                                         ModelHyperParams.trg_vocab_size),
                 buf_size=100000),
-            batch_size=args.batch_size)
+            batch_size=TrainTaskConfig.batch_size)
 
         test_reader = paddle.batch(
-            paddle.dataset.wmt16.validation(ModelHyperParams.src_vocab_size,
-                                            ModelHyperParams.trg_vocab_size),
-            batch_size=args.batch_size)
+            nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
+                                     ModelHyperParams.trg_vocab_size),
+            batch_size=TrainTaskConfig.batch_size)
 
         train_loop(exe, fluid.default_main_program())
     else:
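
Editor's note: two things change in this hunk. The readers switch from the `paddle.dataset.wmt16` loaders to `nist_data_provider`, and since `test_reader` also draws from `nist_data_provider.train("data", ...)`, validation here reads the training split, apparently a stopgap. Batch size also moves from the CLI flag to `TrainTaskConfig`. For reference, the reader pipeline composes as below, with a toy generator standing in for `nist_data_provider`:

    import paddle

    def toy_reader():
        # Stand-in for nist_data_provider.train(...): a Paddle reader is a
        # no-arg callable that returns a generator of samples.
        def reader():
            for i in range(1000):
                yield [i % 7, i % 11], [i % 5]  # (src ids, trg ids)
        return reader

    # shuffle() buffers buf_size samples; batch() groups them into lists.
    train_reader = paddle.batch(
        paddle.reader.shuffle(toy_reader(), buf_size=100000),
        batch_size=32)

    first_batch = next(train_reader())
    print(len(first_batch))  # up to 32 samples per batch
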
@@ -343,15 +336,15 @@ def main():
         train_reader = paddle.batch(
             paddle.reader.shuffle(
-                paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
-                                           ModelHyperParams.trg_vocab_size),
+                nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
+                                         ModelHyperParams.trg_vocab_size),
                 buf_size=100000),
-            batch_size=args.batch_size)
+            batch_size=TrainTaskConfig.batch_size)
 
         test_reader = paddle.batch(
-            paddle.dataset.wmt16.validation(ModelHyperParams.src_vocab_size,
-                                            ModelHyperParams.trg_vocab_size),
-            batch_size=args.batch_size)
+            nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
+                                     ModelHyperParams.trg_vocab_size),
+            batch_size=TrainTaskConfig.batch_size)
 
         trainer_prog = t.get_trainer_program()
         train_loop(exe, trainer_prog)
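
Editor's note: `t.get_trainer_program()` yields the trainer half of the transpiled job; a parameter-server process would run the other half. A hedged sketch of the usual role dispatch in early Fluid distributed examples; `training_role` and `current_endpoint` are assumed names (typically read from environment variables) and do not appear in this diff:

    # Sketch only; method names follow early Fluid distributed examples.
    if training_role == "PSERVER":
        pserver_prog = t.get_pserver_program(current_endpoint)
        pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
        exe.run(pserver_startup)
        exe.run(pserver_prog)  # blocks here, applying gradients from trainers
    elif training_role == "TRAINER":
        trainer_prog = t.get_trainer_program()
        exe.run(fluid.default_startup_program())
        train_loop(exe, trainer_prog)
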
@@ -367,6 +360,4 @@ def print_arguments():
 if __name__ == "__main__":
     print_arguments()
     main()
-
-if __name__ == "__main__":
-    main()