Commit 2946d312 authored by Yu Yang

Update

Parent 29c47d15
@@ -4,7 +4,7 @@ class TrainTaskConfig(object):
     pass_num = 20
     # the number of sequences contained in a mini-batch.
-    batch_size = 64
+    batch_size = 32
     # the hyper parameters for Adam optimizer.
     learning_rate = 0.001
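Note (not part of the commit): the hunk below reads training data in groups of one mini-batch per device via read_multiple(reader=train_data, count=dev_count), so the number of sequences consumed per training step is the per-device batch_size times the GPU count. A minimal sketch of that relationship; the concrete device count here is an assumed example value, not taken from the commit:

    # Illustrative only: effective sequences per step when each of
    # `dev_count` GPUs receives one mini-batch of `batch_size` sequences.
    per_device_batch_size = 32   # TrainTaskConfig.batch_size after this change
    dev_count = 4                # e.g. fluid.core.get_cuda_device_count()
    effective_batch_size = per_device_batch_size * dev_count  # 128 sequences/step
    print(effective_batch_size)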
@@ -196,24 +196,26 @@ def main():
         loss_name=avg_cost.name
         if TrainTaskConfig.use_avg_cost else sum_cost.name)
-    local_scopes = train_exe.executor.local_scopes()
+    dev_count = fluid.core.get_cuda_device_count()
     for pos_enc_param_name in pos_enc_param_names:
         tensor = position_encoding_init(ModelHyperParams.max_length + 1, ModelHyperParams.d_model)
-        for place_id, local_scope in enumerate(local_scopes):
+        for place_id in xrange(dev_count):
+            local_scope = train_exe.executor.local_scope(place_id)
             local_scope.find_var(pos_enc_param_name).get_tensor().set(tensor, fluid.CUDAPlace(place_id))
     train_data = read_multiple(reader=train_data, count=dev_count)
     for pass_id in xrange(TrainTaskConfig.pass_num):
         pass_start_time = time.time()
         for batch_id, data in enumerate(train_data()):
-            for place_id, data_buffer, local_scope in zip(range(len(data)), data, local_scopes):
+            for place_id, data_buffer in enumerate(data):
                 data_input_dict, util_input_dict = prepare_batch_input(
                     data_buffer, data_input_names, util_input_names,
                     ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                     ModelHyperParams.n_head, ModelHyperParams.d_model)
+                local_scope = train_exe.executor.local_scope(place_id)
                 local_scope.find_var(lr_scheduler.learning_rate.name).get_tensor().set(
                     lr_scheduler.update_learning_rate(),
                     fluid.CUDAPlace(place_id))
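Side note on the pattern above (not part of the commit): the training loop stops caching the list returned by train_exe.executor.local_scopes() and instead looks up each device's scope on demand with train_exe.executor.local_scope(place_id), then writes the per-device tensor with its own fluid.CUDAPlace. A minimal sketch of that per-device update as a hypothetical helper; train_exe, var_name, and value are stand-in parameters, and xrange matches the script's Python 2 style:

    import paddle.fluid as fluid

    def set_var_on_each_device(train_exe, var_name, value):
        # Hypothetical helper (not in the commit): write `value` into the copy
        # of `var_name` held by each GPU's local scope of a ParallelExecutor.
        dev_count = fluid.core.get_cuda_device_count()
        for place_id in xrange(dev_count):
            # Fetch the scope owned by device `place_id` on demand instead of
            # caching the whole local_scopes() list up front.
            local_scope = train_exe.executor.local_scope(place_id)
            local_scope.find_var(var_name).get_tensor().set(
                value, fluid.CUDAPlace(place_id))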