Commit 3e51633d authored by gongweibao

cleanup

Parent 844c7bec
@@ -10,7 +10,6 @@ from config import TrainTaskConfig, pos_enc_param_names, \
# FIXME(guosheng): Remove out the batch_size from the model.
batch_size = TrainTaskConfig.batch_size
def position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
......
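For context, `position_encoding_init` fills the sinusoid position encoding table from "Attention Is All You Need": even dimensions take a sine, odd dimensions a cosine, of position-dependent angles. A minimal NumPy sketch of the idea (an illustration, not necessarily the repo's exact code):

```python
import numpy as np

def position_encoding_init(n_position, d_pos_vec):
    # Angle for position pos and dimension j: pos / 10000^(2*(j//2)/d_pos_vec).
    position_enc = np.array([
        [pos / np.power(10000, 2.0 * (j // 2) / d_pos_vec)
         for j in range(d_pos_vec)]
        for pos in range(n_position)])
    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # even dims: sine
    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # odd dims: cosine
    return position_enc.astype("float32")
```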
@@ -8,6 +8,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from model import transformer, position_encoding_init
+import model
from optim import LearningRateScheduler
from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \
encoder_input_data_names, decoder_input_data_names, label_data_names
@@ -71,6 +72,8 @@ parser.add_argument(
"--task_index", type=int, default=0, help="Index of task within the job")
args = parser.parse_args()
+model.batch_size = args.batch_size
def pad_batch_data(insts,
pad_idx,
n_head,
@@ -118,7 +121,6 @@ def pad_batch_data(insts,
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
max_length, n_head):
-print("input_data_name:", input_data_names)
"""
Put all padded data needed by training into a dict.
"""
@@ -152,7 +154,6 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
lbl_word, lbl_weight
]))
-#print("input_dict", input_dict)
return input_dict
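`prepare_batch_input` leans on `pad_batch_data` to bring variable-length instances to a common length before anything is fed to the executor. A simplified sketch of that padding step (the real helper also returns position data and the `n_head`-shaped attention-bias tensors, omitted here):

```python
import numpy as np

def pad_batch_data(insts, pad_idx):
    # Pad every instance to the length of the longest one in the batch.
    max_len = max(len(inst) for inst in insts)
    padded = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts],
        dtype="int64")
    # fluid's embedding layers expect a trailing dimension of 1.
    return padded.reshape([-1, max_len, 1])
```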
@@ -199,7 +200,7 @@ def main():
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head)
-test_cost = exe.run(test_program,
+test_cost = exe.run(inference_program,
feed=data_input,
fetch_list=[cost])[0]
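The fix above evaluates with `inference_program` instead of the undefined `test_program`. In fluid, the evaluation program is usually derived from the training program before optimizer ops are appended; a sketch of the common pattern (assumed here, not shown in this diff):

```python
import paddle.fluid as fluid

# Clone the main program for evaluation so that ops such as dropout
# switch to test-time behavior; no optimizer ops run on this copy.
inference_program = fluid.default_main_program().clone(for_test=True)
```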
@@ -210,7 +211,6 @@
ts = time.time()
for pass_id in xrange(args.pass_num):
for batch_id, data in enumerate(train_reader()):
-print("batch_id:", batch_id)
# The current program desc is coupled with batch_size, thus all
# mini-batches must have the same number of instances currently.
if len(data) != args.batch_size:
@@ -223,11 +223,8 @@
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head)
-#print("feed0:", data_input)
-#print("fetch_list0:", [cost])
lr_scheduler.update_learning_rate(data_input)
-print("before exe run in train_loop")
outs = exe.run(trainer_prog,
feed=data_input,
fetch_list=[cost],
@@ -239,9 +236,7 @@
# Validate and save the model for inference.
val_cost = test(exe)
-#pass_elapsed = time.time() - start_time
-#print("pass_id = " + str(pass_id) + " val_cost = " + str(val_cost))
-print("pass_id = %d batch = %d cost = %f speed = %.2f sample/s" %
+print("pass_id = %d cost = %f avg_speed = %.2f sample/s" %
(pass_id, batch_id, cost_val, len(data) / (time.time() - ts)))
if args.local:
@@ -298,9 +293,6 @@
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
-#print("cost 0:", cost)
-#print("before run start up")
# Parameter initialization
exe.run(fluid.default_startup_program())
@@ -327,13 +319,7 @@
ModelHyperParams.trg_vocab_size),
batch_size=args.batch_size)
-#print("before get trainer program")
trainer_prog = t.get_trainer_program()
-#print("before start")
-# feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
-# TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
-# exe.run(fluid.default_startup_program())
train_loop(exe, trainer_prog)
else:
print("environment var TRAINER_ROLE should be TRAINER or PSERVER")
......
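The PSERVER/TRAINER branches above follow fluid's distribute-transpiler pattern: a single transpile step rewrites the program into role-specific variants, and each process runs only its own. A rough sketch of the flow (argument and variable names are assumptions; the transpiler API changed across early fluid releases):

```python
import os
import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
# Split the single-node program into pserver- and trainer-side programs.
t.transpile(trainer_id=args.task_index,
            pservers=pserver_endpoints,  # e.g. "ip1:port1,ip2:port2" (assumed)
            trainers=trainer_count)      # assumed trainer count

training_role = os.getenv("TRAINER_ROLE")  # env var named in the message above
if training_role == "PSERVER":
    # current_endpoint: this pserver's own "ip:port" (derivation omitted).
    pserver_prog = t.get_pserver_program(current_endpoint)
    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
    exe.run(pserver_startup)
    exe.run(pserver_prog)  # blocks here, serving parameters to trainers
elif training_role == "TRAINER":
    exe.run(fluid.default_startup_program())  # parameter initialization
    train_loop(exe, t.get_trainer_program())
```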
@@ -28,7 +28,6 @@ class LearningRateScheduler(object):
dtype="float32",
persistable=True)
self.place = place
-#print("LearningRateScheduler init learning_rate_name:", self.learning_rate.name)
def update_learning_rate(self, data_input):
self.current_steps += 1
@@ -38,7 +37,4 @@ class LearningRateScheduler(object):
])
lr_tensor = fluid.LoDTensor()
lr_tensor.set(np.array([lr_value], dtype="float32"), self.place)
-#print("in learning_rate")
-#print("learning_rate_name:", self.learning_rate.name)
-#print("data_input:", data_input)
data_input[self.learning_rate.name] = lr_tensor
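`update_learning_rate` computes the learning rate on the host each step and feeds it to the program as an ordinary input tensor, which is why it lands in `data_input` under the variable's name. The elided value computation presumably follows the Transformer paper's warmup schedule; a minimal sketch of that formula (an assumption, since the computation itself is collapsed out of this diff):

```python
def warmup_lr(step, d_model, warmup_steps, scale=1.0):
    # lr = scale * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    return scale * d_model ** -0.5 * min(step ** -0.5,
                                         step * warmup_steps ** -1.5)
```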